Session 1. End to End ML
Introduction
According to the National Heart, Lung and Blood Institute:
Heart disease is a catch-all phrase for a variety of conditions that affect the heart’s structure and function. Coronary heart disease is a type of heart disease that develops when the arteries of the heart cannot deliver enough oxygen-rich blood to the heart. It is the leading cause of death in the United States.
(Emphasis mine. Source: https://www.nhlbi.nih.gov/health-topics/espanol/enfermedad-coronaria)
Also, according to the World Health Organization, cardiovascular diseases are the leading cause of death globally (source: https://www.who.int/news-room/fact-sheets/detail/cardiovascular-diseases-(cvds)).
In this notebook we try to learn enough about this topic to understand the Heart Disease UCI dataset and to build simple models that predict whether a patient has a disease or not, based on features like the heart rate during exercise or the cholesterol levels in the blood.
Blood and heart
Blood is very important to ensure the proper functioning of the body. Its functions cover the transport of oxygen and nutrients to the cells of the body as well as the removal of the cellular waste products.
Blood is transported around the body by the pumping of the heart. This organ receives oxygen-poor blood and sends it to the lungs to be oxygenated, and it sends the oxygen-rich blood coming back from the lungs to the rest of the body.
By josiño - Own work, Public Domain, https://commons.wikimedia.org/w/index.php?curid=9396374. Flow of the blood through the chambers of the heart. Blue arrows represent oxygen-poor blood received from the rest of the body and sent to the lungs. Red arrows represent oxygen-rich blood coming from the lungs that is sent to the rest of the body.
An inadequate blood supply can leave cells without enough energy to function properly, causing the death of the cells in the worst case.
Coronary heart disease
The heart also needs oxygen and nutrients to function properly; these arrive through arteries known as the coronary arteries. When we talk about a coronary disease, we usually mean a restriction of blood flow in these arteries due to the accumulation of substances on their walls.
By NIH: National Heart, Lung and Blood Institute - http://www.nhlbi.nih.gov/health/health-topics/topics/heartattack/, Public Domain, https://commons.wikimedia.org/w/index.php?curid=25287085. Death of heart cells due to an ischemia in the coronary arteries.
In the worst case, the impact of leaving the cells of the heart without nutrients and oxygen is a heart attack, in other words, the death of part of the heart cells. This, in turn, would have an impact on the rest of the body because the pumping of the heart would be affected.
Glossary of terms
Atherosclerosis: accumulation of substances on the walls of arteries which can hinder the blood flow. Moreover, the rupture of this plaque of substances can cause the formation of a blood clot (thrombus) that, in turn, can block even more the affected area or go to other parts of the body and block those parts (embolism). (Sources: American Heart Association, Mayo Clinic)
Ischemia: blood flow reduction to a tissue. This implies a reduction of the supply of oxygen and nutrients, so cells won’t get enough energy to function properly. (Sources: American Heart Association, Mayo Clinic, Wikipedia)
Angina: chest pain due to a blood flow reduction in the coronary arteries. (Sources: United Kingdom National Health Service, (Spanish) Video sobre angina de Alberto Sanagustín)
Stable angina: angina triggered by situations that demand more oxygen (for example, exercise or stress); it goes away with rest.
Unstable angina: angina that can happen even at rest.
Typical & atypical angina: typical angina usually means chest discomfort. However, it seems some people can experience other symptoms such as nausea or shortness of breath; in these cases we speak of atypical angina. (Sources: Harrington Hospital, Wikipedia).
Thrombus: blood mass in solid state that hinders the blood flow in a blood vessel. (Source: MedlinePlus)
Embolus: thrombus that detaches and travels to other parts of the body. (Source: MedlinePlus)
Acute myocardial infarction: also known as heart attack, is the death of part of the heart tissue due to an ischemia. In other words, it is the death of part of the cells due to the lack of oxygen. (Sources: Healthline, Wikipedia)
Electrocardiogram: graphical record of the electrical signals that cause heartbeats. Each part of the record of a normal heartbeat has a name; the most interesting ones for this project are the T wave and the ST segment, because they can give some information about the presence of issues like an ischemia or infarction. (Sources: Mayo Clinic, Wikipedia, (Spanish) Video sobre electrocardiograma de Alberto Sanagustín, (Spanish) Serie de videos sobre el electrocardiograma normal de Alberto Sanagustín)
Nuclear stress test: a radioactive dye is injected into the patient to see the blood flow on rest and doing exercise. Moreover, during this test the activity of the heart is also measured with an electrocardiogram. (Sources: Mayo Clinic, Healthline)
Asymptomatic disease: a disease that a patient has but with very few or no symptoms. (Sources: (Spanish) definicion.de, Mayo Clinic, Wikipedia)
Left ventricular hypertrophy: thickening of the walls of the main heart chamber that pumps the blood to the rest of the body. This can cause the muscle to lose elasticity which, in turn, causes the heart to not work properly. (Source: Mayo Clinic)
0.- Libraries required
This is the list of libraries required for this hands-on session:
library(caret)
library(rpart.plot)
library(tidyverse)
library(dplyr)
library(knitr)
library(ggpubr)
library(skimr)
library(ggplot2)
library(gridExtra)
library(pheatmap)
library(rsample)
library(recipes)
library(GGally)
library(visdat)
library(glmnet)
library(precrec)
library(kableExtra)
library(patchwork)
1.- The Data set and Exploratory data analysis (EDA)
setwd("~/Library/Mobile Documents/com~apple~CloudDocs/Master/Machine learning/sesion1")
data <- read.csv('heart_mod.csv')
# Save the theme so it can be reused in all figures
MY_THEME <- theme(
text = element_text(family = "Roboto"),
axis.text.x = element_text(angle = 35, vjust = .6),
axis.title.x = element_blank(),
axis.ticks = element_blank(),
axis.line = element_line(colour = "grey50"),
panel.grid = element_line(color = "#b4aea9"),
panel.grid.minor = element_blank(),
panel.grid.major.x = element_blank(),
panel.grid.major.y = element_line(linetype = "dashed"),
panel.background = element_rect(fill = "#fbf9f4", color = "#fbf9f4"),
plot.background = element_rect(fill = "#fbf9f4", color = "#fbf9f4"),
legend.background = element_rect(fill = "#fbf9f4"),
plot.title = element_text(
family = "Roboto",
size = 16,
face = "bold",
color = "#2a475e",
margin = margin(b = 20)
)
)

The main goal of this step is to achieve a better understanding of what information each variable contains, as well as to detect possible errors. Some common examples are:
That a column has been stored with the wrong type: a numeric variable is being recognized as text or vice versa.
That a variable contains values that do not make sense.
1.1 Variable type
There are different functions in R that help us summarize the types of the variables we have: glimpse, summary, or skim.
## Rows: 303
## Columns: 15
## $ X <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18…
## $ age <int> 63, 37, 41, 56, 57, 57, 56, 44, 52, 57, 54, 48, 49, 64, 58, 5…
## $ sex <int> 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1…
## $ cp <int> 3, 2, 1, 1, 0, 0, 1, 1, 2, 2, 0, 2, 1, 3, 3, 2, 2, 3, 0, 3, 0…
## $ trestbps <int> 145, 130, 130, 120, 120, 140, 140, 120, 172, 150, 140, 130, 1…
## $ chol <int> 233, 250, 204, 236, 354, 192, 294, 263, 199, 168, 239, 275, 2…
## $ fbs <int> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0…
## $ restecg <int> 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1…
## $ thalach <int> 150, 187, 172, 178, 163, 148, 153, 173, 162, 174, 160, 139, 1…
## $ exang <int> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0…
## $ oldpeak <dbl> 2.3, 3.5, 1.4, 0.8, 0.6, 0.4, 1.3, 0.0, 0.5, 1.6, 1.2, 0.2, 0…
## $ slope <int> 0, 0, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 0, 2, 2, 1…
## $ ca <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0…
## $ thal <int> 1, 2, 2, 2, 2, 1, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3…
## $ target <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## X age sex cp
## Min. : 1.0 Min. :29.00 Min. :0.0000 Min. :0.000
## 1st Qu.: 76.5 1st Qu.:47.50 1st Qu.:0.0000 1st Qu.:0.000
## Median :152.0 Median :55.00 Median :1.0000 Median :1.000
## Mean :152.0 Mean :54.37 Mean :0.6832 Mean :0.967
## 3rd Qu.:227.5 3rd Qu.:61.00 3rd Qu.:1.0000 3rd Qu.:2.000
## Max. :303.0 Max. :77.00 Max. :1.0000 Max. :3.000
##
## trestbps chol fbs restecg
## Min. : 94.0 Min. :126.0 Min. :0.0000 Min. :0.0000
## 1st Qu.:120.0 1st Qu.:211.0 1st Qu.:0.0000 1st Qu.:0.0000
## Median :130.0 Median :240.0 Median :0.0000 Median :1.0000
## Mean :131.6 Mean :246.3 Mean :0.1485 Mean :0.5281
## 3rd Qu.:140.0 3rd Qu.:274.5 3rd Qu.:0.0000 3rd Qu.:1.0000
## Max. :200.0 Max. :564.0 Max. :1.0000 Max. :2.0000
##
## thalach exang oldpeak slope
## Min. : 71.0 Min. :0.0000 Min. :0.00 Min. :0.000
## 1st Qu.:133.5 1st Qu.:0.0000 1st Qu.:0.00 1st Qu.:1.000
## Median :153.0 Median :0.0000 Median :0.80 Median :1.000
## Mean :149.6 Mean :0.3267 Mean :1.04 Mean :1.399
## 3rd Qu.:166.0 3rd Qu.:1.0000 3rd Qu.:1.60 3rd Qu.:2.000
## Max. :202.0 Max. :1.0000 Max. :6.20 Max. :2.000
##
## ca thal target
## Min. :0.0000 Min. :1.000 Min. :0.0000
## 1st Qu.:0.0000 1st Qu.:2.000 1st Qu.:0.0000
## Median :0.0000 Median :2.000 Median :1.0000
## Mean :0.6745 Mean :2.329 Mean :0.5446
## 3rd Qu.:1.0000 3rd Qu.:3.000 3rd Qu.:1.0000
## Max. :3.0000 Max. :3.000 Max. :1.0000
## NA's :5 NA's :2
| Name | data |
| Number of rows | 303 |
| Number of columns | 15 |
| _______________________ | |
| Column type frequency: | |
| numeric | 15 |
| ________________________ | |
| Group variables | None |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| X | 0 | 1.00 | 152.00 | 87.61 | 1 | 76.5 | 152.0 | 227.5 | 303.0 | ▇▇▇▇▇ |
| age | 0 | 1.00 | 54.37 | 9.08 | 29 | 47.5 | 55.0 | 61.0 | 77.0 | ▁▆▇▇▁ |
| sex | 0 | 1.00 | 0.68 | 0.47 | 0 | 0.0 | 1.0 | 1.0 | 1.0 | ▃▁▁▁▇ |
| cp | 0 | 1.00 | 0.97 | 1.03 | 0 | 0.0 | 1.0 | 2.0 | 3.0 | ▇▃▁▅▁ |
| trestbps | 0 | 1.00 | 131.62 | 17.54 | 94 | 120.0 | 130.0 | 140.0 | 200.0 | ▃▇▅▁▁ |
| chol | 0 | 1.00 | 246.26 | 51.83 | 126 | 211.0 | 240.0 | 274.5 | 564.0 | ▃▇▂▁▁ |
| fbs | 0 | 1.00 | 0.15 | 0.36 | 0 | 0.0 | 0.0 | 0.0 | 1.0 | ▇▁▁▁▂ |
| restecg | 0 | 1.00 | 0.53 | 0.53 | 0 | 0.0 | 1.0 | 1.0 | 2.0 | ▇▁▇▁▁ |
| thalach | 0 | 1.00 | 149.65 | 22.91 | 71 | 133.5 | 153.0 | 166.0 | 202.0 | ▁▂▅▇▂ |
| exang | 0 | 1.00 | 0.33 | 0.47 | 0 | 0.0 | 0.0 | 1.0 | 1.0 | ▇▁▁▁▃ |
| oldpeak | 0 | 1.00 | 1.04 | 1.16 | 0 | 0.0 | 0.8 | 1.6 | 6.2 | ▇▂▁▁▁ |
| slope | 0 | 1.00 | 1.40 | 0.62 | 0 | 1.0 | 1.0 | 2.0 | 2.0 | ▁▁▇▁▇ |
| ca | 5 | 0.98 | 0.67 | 0.94 | 0 | 0.0 | 0.0 | 1.0 | 3.0 | ▇▃▁▂▁ |
| thal | 2 | 0.99 | 2.33 | 0.58 | 1 | 2.0 | 2.0 | 3.0 | 3.0 | ▁▁▇▁▆ |
| target | 0 | 1.00 | 0.54 | 0.50 | 0 | 0.0 | 1.0 | 1.0 | 1.0 | ▇▁▁▁▇ |
1.2 Dataset features
This dataset is hosted on Kaggle (Heart Disease UCI) and originally comes from the UCI Machine Learning Repository. It contains records of 303 patients from Cleveland; the features are described below.
Attribute Information:
- age
- sex
- chest pain type (4 values)
- resting blood pressure
- serum cholesterol in mg/dl
- fbs: fasting blood sugar > 120 mg/dl
From here on, the variables are related to a nuclear stress test, i.e., a stress test where a radioactive dye is also injected into the patient to see the blood flow:
- restecg: resting electrocardiographic results (values 0,1,2)
- thalach: maximum heart rate achieved
- exang: exercise induced angina
- oldpeak: ST depression induced by exercise relative to rest
- slope: the slope of the peak exercise ST segment
- ca: number of major vessels (0-3) colored by flourosopy
- thal: 3 = normal; 6 = fixed defect; 7 = reversible defect
- target: 0 = yes; 1 = No
TASK TO DO:
- Remove X column.
- Transform categorical variable to R factors.
- Give (if necessary) a better name to the factor values (it will be helpful for the graphs).
# Transform categorical variable to R factors
data <- data %>%
mutate(across(c(
sex, cp, fbs, restecg, exang, slope, thal, target, ca
), as.factor))

# Give a better name to the factor values for the graphs
levels(data$sex) <- c("Female", "Male")
levels(data$cp) <- c("Asymptomatic", "Atypical angina", "No angina", "Typical angina")
levels(data$fbs) <- c("No", "Yes")
levels(data$restecg) <- c("Hypertrophy", "Normal", "Abnormalities")
levels(data$exang) <- c("No", "Yes")
levels(data$slope) <- c("Descending", "Flat", "Ascending")
levels(data$thal) <- c("Fixed defect", "Normal flow", "Reversible defect")
levels(data$target) <- c("Yes", "No")

Next step: inspect all variables and form hypotheses about how each variable affects heart disease incidence (target column).
1.2.1 Inspect variables: target
Target variable: whether the patient has a heart disease or not
- Value 0: yes
- Value 1: no
We can see that the distribution is quite balanced. Thanks to this, accuracy is a reasonable metric to evaluate how well the models perform.
ggplot(data, aes(target, fill=target)) +
geom_bar(width = .6) +
labs(x="Disease", y="Number of patients") +
guides(fill="none") + MY_THEME

1.2.2 Inspect variables: age vs. target
Visualize how age affects the chances of having or not having a heart attack.
Options: density plots or boxplots. Check on Google how to do it with ggplot.
data %>%
ggplot(aes(x = target, y = age)) +
geom_violin(width = .6) +
geom_boxplot(width = .15) +
geom_point(
aes(col = target),
position = position_jitterdodge(jitter.width = .3),
alpha = .2,
size = 3
) +
theme_pubclean() +
MY_THEME

1.2.3 Inspect variables: sex vs. target
Patient sex
- Value 0: female
- Value 1: male
Options: bar charts with the number of cases, bar charts with the proportion of cases in each group, and bar charts separated by class.
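None of these options is coded in the notebook at this point; a minimal sketch of one of them (proportions per sex group), reusing the relabeled factors and the MY_THEME object defined earlier:

```r
# Proportion of patients with/without disease within each sex group
data %>%
  ggplot(aes(x = sex, fill = target)) +
  geom_bar(position = "fill", width = .5) +
  labs(y = "Proportion of patients", fill = "Disease") +
  MY_THEME
```

Using position = "fill" normalizes each bar to 1, which makes the two sex groups directly comparable despite their different sizes.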
1.2.x Inspect variables: variable x vs. target
Complete with the rest of the variables.
Since in this case we only have 13 variables, it seems reasonable to go one by one. There is also the ggpairs function to see how the variables are related. If we had many variables, this would not be feasible; a useful alternative for continuous variables is a heatmap.
numeric_variables <- which(unlist(lapply(data, is.numeric)))
pheatmap(cor(data[, numeric_variables[-6]]))

vars <- data %>% colnames()
my.plot <- function(var){
if (class(data[, var]) == "factor") {
p <- data %>%
ggplot(aes_string(x = var, fill = "target")) +
geom_bar(width = .5, alpha = .9) +
ggtitle(paste0(var, " vs target"))
} else {
p <- data %>%
ggplot(aes_string(x = "target", y = var)) +
geom_violin(width = .6) +
geom_boxplot(width = .15) +
geom_point(
aes_string(col = "target"),
position = position_jitterdodge(jitter.width = .3),
alpha = .2,
size = .5
) +
ggtitle(paste0(var, " vs target")) +
guides(col = "none")
}
return(p)
}
p <- lapply(vars, my.plot)
p %>%
wrap_plots(ncol = 3, guides = "collect") +
guide_area() &
MY_THEME

Hypotheses drawn from this last figure:
- age: the concentration of individuals with heart attacks is noticeably higher from age 50 onward, so age is a risk factor.
- sex: while a relatively small percentage of women have suffered an infarction, more than half of the men have, so being male is also a risk factor.
- cp: regarding chest pain, in most cases the patients were asymptomatic, so it does not seem a significant indicator.
- trestbps: there is a slight skew toward high values in the group with an attack, so high values may be indicative, but in most cases resting blood pressure was in the normal range.
- chol: there does not seem to be a difference in cholesterol values between the two groups.
- fbs: most individuals had values below 120 mg/dL (both with and without an attack), and the ratio seems similar in both cases (~50/50).
target: 0 = yes; 1 = No
|  | Hypertrophy | Normal | Abnormalities |
|---|---|---|---|
| Yes | 79 | 56 | 3 |
| No | 68 | 96 | 1 |
- restecg: to assess this graph properly, I decided to look at the table of proportions, since the number of cases with Abnormalities was very low. From it we see that hypertrophy increases the risk of an attack compared with a normal situation, and that with abnormalities the risk skyrockets.
- thalach: individuals without a heart attack reach significantly higher heart rates.
- exang: individuals with exercise-induced angina are much more likely to have suffered a heart attack.
- oldpeak: individuals with high ST depression values are far more likely to have suffered a heart attack.
- slope: individuals with a descending or flat slope have a much higher risk than those with an ascending one.
- ca: the more major vessels colored by fluoroscopy, the higher the risk.
- thal: having defects, whether fixed or reversible, severely increases the risk.
- target: the class distribution is fairly balanced, which makes it good for training a classification model.
2. Data splitting
Given a fixed amount of data, typical recommendations for splitting your data into training-test splits include 60% (training)–40% (testing), 70%–30%, or 80%–20%. Generally speaking, these are appropriate guidelines to follow; however, it is good to keep the following points in mind:
Spending too much in training (e.g., >80%) won’t allow us to get a good assessment of predictive performance. We may find a model that fits the training data very well, but is not generalizable (overfitting).
Sometimes too much spent in testing (>40%) won’t allow us to get a good assessment of model parameters.
2.1 Random Sampling
Using the library rsample and its corresponding
functions:
- initial_split
- training
- testing
Remember to use the function set.seed in order to
replicate the results.
set.seed(123) #important in order to replicate
split_basico <- initial_split(data, prop = .7)
sb_train <- training(split_basico)
sb_test <- testing(split_basico)
p1 <- sb_train %>%
ggplot(aes(x = "target", fill = target)) +
geom_bar(position = "fill", width = .4) +
MY_THEME
p2 <- sb_test %>%
ggplot(aes(x = "target", fill = target)) +
geom_bar(position = "fill", width = .4) +
MY_THEME
wrap_plots(list(p1, p2), guides = "collect") # a slight class imbalance is visible

| Var1 | Freq |
|---|---|
| Yes | 0.4554455 |
| No | 0.5445545 |
| Var1 | Freq |
|---|---|
| Yes | 0.4433962 |
| No | 0.5566038 |
| Var1 | Freq |
|---|---|
| Yes | 0.4835165 |
| No | 0.5164835 |
We can see that the class balance of the target is not constant. Let's stratify to guarantee that this balance is preserved in both train and test.
2.2 Stratified Sampling
If we want to explicitly control the sampling so that our training and test sets have similar Y distributions, we can use stratified sampling.
This is more common with classification problems where the response variable may be severely imbalanced (e.g., 90% of observations with response “Yes” and 10% with response “No”).
Check the help of the function initial_split to see how
to do it.
We can use the functions table and
prop.table to check if training and test sets have similar
Y distributions.
set.seed(123) #important in order to replicate
split_strat <- initial_split(data, prop = .7, strata = "target")
strat_train <- training(split_strat)
strat_test <- testing(split_strat)
p1 <- strat_train %>%
ggplot(aes(x = "target", fill = target)) +
geom_bar(position = "fill", width = .4) +
MY_THEME
p2 <- strat_test %>%
ggplot(aes(x = "target", fill = target)) +
geom_bar(position = "fill", width = .4) +
MY_THEME
wrap_plots(list(p1, p2), guides = "collect") # balanced classes

| Var1 | Freq |
|---|---|
| Yes | 0.4554455 |
| No | 0.5445545 |
| Var1 | Freq |
|---|---|
| Yes | 0.4549763 |
| No | 0.5450237 |
| Var1 | Freq |
|---|---|
| Yes | 0.4565217 |
| No | 0.5434783 |
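The three proportion tables above (full dataset, train, test) can be produced with table and prop.table, e.g.:

```r
# Class balance in the full dataset and in the stratified splits
prop.table(table(data$target))
prop.table(table(strat_train$target))
prop.table(table(strat_test$target))
```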
Now the Yes/No proportion is indeed constant across the full dataset, the test set, and the train set.
3. Feature and target engineering
3.1 Imputation of missing values
We can use the function vis_miss of the library
visdat that provides a glance ggplot of the missingness
inside a dataframe.
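The code for this check is not shown; a minimal sketch of what it might look like (the ## [1] 7 output below suggests the total NA count was also printed):

```r
library(visdat)
vis_miss(data)    # ggplot overview of missingness per column
sum(is.na(data))  # total number of missing values (5 in ca + 2 in thal)
```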
## [1] 7
Then, to impute missing values we can use the recipe R
package.
It has the following steps:
- recipe: a recipe is a description of the steps to be applied to a data set in order to prepare it for data analysis.
- step_impute_xxxx: creates a specification of a recipe step that will substitute missing values:
  - step_impute_mean substitutes missing values of numeric variables with the training set mean of those variables.
  - step_impute_knn imputes missing data using nearest neighbors. It can be applied to both numeric and categorical variables.
We prepare a recipe to impute the NAs in all predictors.
prep: For a recipe with at least one preprocessing operation, estimate the required parameters from a training set that can be later applied to other data sets.
We train the KNN models that will be used for imputation, using the training data.
bake: For a recipe with at least one preprocessing operation that has been trained byprep, apply the computations to new data.
Finally, we impute the NAs with the trained models.
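The recipe and prep calls are not shown in the notebook; a minimal sketch of what they might look like, so that the name trained_recipe matches the bake calls below (the intermediate name impute_recipe is an assumption):

```r
library(recipes)

# Recipe: KNN-impute every predictor, to be learned from the training split
impute_recipe <- recipe(target ~ ., data = strat_train) %>%
  step_impute_knn(all_predictors())

# Estimate the imputation models on the training data only
trained_recipe <- prep(impute_recipe, training = strat_train)
```

Fitting the recipe on the training split only (and then baking both splits) avoids leaking test-set information into the imputation.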
datos_train_prep <- bake(trained_recipe, new_data = strat_train)
datos_test_prep <- bake(trained_recipe, new_data = strat_test)
wrap_plots(list(vis_miss(datos_train_prep, cluster = T), vis_miss(datos_test_prep)), guides = "collect", nrow = 2)

Now we can see that there are no NAs left.
4. Creation of the model
4.1 Logistic regression using glmnet
We are going to use the function glmnet. From its help
page:
“Fit a generalized linear model via penalized maximum likelihood. The regularization path is computed for the lasso or elasticnet penalty at a grid of values for the regularization parameter lambda. Can deal with all shapes of data, including very large sparse data matrices. Fits linear, logistic and multinomial, poisson, and Cox regression models.”
4.1.1 logistic regression (\(\lambda = 0\))
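No code is shown for this step; a minimal sketch, assuming the model matrix glmnet_traindata is built from the prepared training data (the same name is used by cv.glmnet in the next subsection):

```r
library(glmnet)

# glmnet needs a numeric matrix: dummy-encode factors, drop the intercept column
glmnet_traindata <- model.matrix(target ~ ., data = datos_train_prep)[, -1]

# Plain (unpenalized) logistic regression: a single fit with lambda = 0
fit_LogReg <- glmnet(x = glmnet_traindata,
                     y = datos_train_prep$target,
                     family = "binomial",
                     lambda = 0)
coef(fit_LogReg)
```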
4.1.2 logistic regression with penalty: lasso
set.seed(123)
fit_LogReg_cv_lasso <- cv.glmnet(x = glmnet_traindata,
y = datos_train_prep$target,
family = "binomial",
nfold=10,
alpha=1,
type.measure = "auc")
plot(fit_LogReg_cv_lasso)

4.2 Logistic regression using caret
4.2.1 By default caret
Before training the model, we need to apply the function
trainControl to specify some training parameters.
control_train <-
trainControl(
method = "cv",
# resampling method: "boot", "cv", "none", etc.
number = 10,
# number of folds or number of resampling iterations.
returnResamp = "all",
classProbs = TRUE,
search = "grid",
savePredictions = TRUE
)

Once we have established the training characteristics, we train our model with the function train. This function “sets up a grid of tuning parameters for a number of classification and regression routines, fits each model and calculates a resampling based performance measure.”
set.seed(123)
modelo_glm_caret <- train(
target ~ .,
method = "glmnet",
family = "binomial",
trControl = control_train,
data = datos_train_prep,
# train data.
metric = "Accuracy"
)
modelo_glm_caret

## glmnet
##
## 211 samples
## 13 predictor
## 2 classes: 'Yes', 'No'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 190, 189, 191, 191, 189, 190, ...
## Resampling results across tuning parameters:
##
## alpha lambda Accuracy Kappa
## 0.10 0.0005010359 0.8208009 0.6329692
## 0.10 0.0050103594 0.8103463 0.6107406
## 0.10 0.0501035939 0.8246537 0.6419031
## 0.55 0.0005010359 0.8158009 0.6214351
## 0.55 0.0050103594 0.8151082 0.6200658
## 0.55 0.0501035939 0.7878788 0.5685067
## 1.00 0.0005010359 0.8158009 0.6214351
## 1.00 0.0050103594 0.8151082 0.6200658
## 1.00 0.0501035939 0.7917316 0.5756382
##
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0.1 and lambda = 0.05010359.
Let’s see some of the information we have in
modelo_glm_caret:
best combination of alpha and lambda:
## alpha lambda
## 3 0.1 0.05010359
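This output comes from the bestTune element of the fitted caret object, e.g.:

```r
modelo_glm_caret$bestTune  # data frame with the selected alpha and lambda
```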
The result of each fold:
performance <- modelo_glm_caret$resample
kbl(performance) %>%
kable_paper() %>%
scroll_box(width = "100%", height = "200px")

| alpha | lambda | Accuracy | Kappa | Resample |
|---|---|---|---|---|
| 0.10 | 0.0501036 | 0.9047619 | 0.8000000 | Fold01 |
| 0.10 | 0.0050104 | 0.9047619 | 0.8000000 | Fold01 |
| 0.10 | 0.0005010 | 0.9047619 | 0.8000000 | Fold01 |
| 0.55 | 0.0501036 | 0.8571429 | 0.6956522 | Fold01 |
| 0.55 | 0.0050104 | 0.9047619 | 0.8000000 | Fold01 |
| 0.55 | 0.0005010 | 0.9047619 | 0.8000000 | Fold01 |
| 1.00 | 0.0501036 | 0.9047619 | 0.8000000 | Fold01 |
| 1.00 | 0.0050104 | 0.9047619 | 0.8000000 | Fold01 |
| 1.00 | 0.0005010 | 0.9047619 | 0.8000000 | Fold01 |
| 0.10 | 0.0501036 | 0.7727273 | 0.5378151 | Fold02 |
| 0.10 | 0.0050104 | 0.6818182 | 0.3529412 | Fold02 |
| 0.10 | 0.0005010 | 0.6818182 | 0.3529412 | Fold02 |
| 0.55 | 0.0501036 | 0.7272727 | 0.4500000 | Fold02 |
| 0.55 | 0.0050104 | 0.6818182 | 0.3529412 | Fold02 |
| 0.55 | 0.0005010 | 0.6818182 | 0.3529412 | Fold02 |
| 1.00 | 0.0501036 | 0.7272727 | 0.4500000 | Fold02 |
| 1.00 | 0.0050104 | 0.6818182 | 0.3529412 | Fold02 |
| 1.00 | 0.0005010 | 0.6818182 | 0.3529412 | Fold02 |
| 0.10 | 0.0501036 | 1.0000000 | 1.0000000 | Fold03 |
| 0.10 | 0.0050104 | 1.0000000 | 1.0000000 | Fold03 |
| 0.10 | 0.0005010 | 1.0000000 | 1.0000000 | Fold03 |
| 0.55 | 0.0501036 | 1.0000000 | 1.0000000 | Fold03 |
| 0.55 | 0.0050104 | 1.0000000 | 1.0000000 | Fold03 |
| 0.55 | 0.0005010 | 1.0000000 | 1.0000000 | Fold03 |
| 1.00 | 0.0501036 | 1.0000000 | 1.0000000 | Fold03 |
| 1.00 | 0.0050104 | 1.0000000 | 1.0000000 | Fold03 |
| 1.00 | 0.0005010 | 1.0000000 | 1.0000000 | Fold03 |
| 0.10 | 0.0501036 | 0.6000000 | 0.1578947 | Fold04 |
| 0.10 | 0.0050104 | 0.5500000 | 0.0425532 | Fold04 |
| 0.10 | 0.0005010 | 0.6000000 | 0.1578947 | Fold04 |
| 0.55 | 0.0501036 | 0.6500000 | 0.2553191 | Fold04 |
| 0.55 | 0.0050104 | 0.5500000 | 0.0425532 | Fold04 |
| 0.55 | 0.0005010 | 0.5500000 | 0.0425532 | Fold04 |
| 1.00 | 0.0501036 | 0.6500000 | 0.2553191 | Fold04 |
| 1.00 | 0.0050104 | 0.5500000 | 0.0425532 | Fold04 |
| 1.00 | 0.0005010 | 0.5500000 | 0.0425532 | Fold04 |
| 0.10 | 0.0501036 | 0.7727273 | 0.5378151 | Fold05 |
| 0.10 | 0.0050104 | 0.7727273 | 0.5378151 | Fold05 |
| 0.10 | 0.0005010 | 0.7272727 | 0.4406780 | Fold05 |
| 0.55 | 0.0501036 | 0.7272727 | 0.4500000 | Fold05 |
| 0.55 | 0.0050104 | 0.7727273 | 0.5378151 | Fold05 |
| 0.55 | 0.0005010 | 0.7272727 | 0.4406780 | Fold05 |
| 1.00 | 0.0501036 | 0.7727273 | 0.5378151 | Fold05 |
| 1.00 | 0.0050104 | 0.7727273 | 0.5378151 | Fold05 |
| 1.00 | 0.0005010 | 0.7272727 | 0.4406780 | Fold05 |
| 0.10 | 0.0501036 | 0.7142857 | 0.4220183 | Fold06 |
| 0.10 | 0.0050104 | 0.6666667 | 0.3287671 | Fold06 |
| 0.10 | 0.0005010 | 0.6666667 | 0.3287671 | Fold06 |
| 0.55 | 0.0501036 | 0.6666667 | 0.3225806 | Fold06 |
| 0.55 | 0.0050104 | 0.7142857 | 0.4220183 | Fold06 |
| 0.55 | 0.0005010 | 0.6666667 | 0.3287671 | Fold06 |
| 1.00 | 0.0501036 | 0.6666667 | 0.3225806 | Fold06 |
| 1.00 | 0.0050104 | 0.7142857 | 0.4220183 | Fold06 |
| 1.00 | 0.0005010 | 0.6666667 | 0.3287671 | Fold06 |
| 0.10 | 0.0501036 | 0.8500000 | 0.7000000 | Fold07 |
| 0.10 | 0.0050104 | 0.8500000 | 0.6938776 | Fold07 |
| 0.10 | 0.0005010 | 0.9500000 | 0.8979592 | Fold07 |
| 0.55 | 0.0501036 | 0.8500000 | 0.7000000 | Fold07 |
| 0.55 | 0.0050104 | 0.8500000 | 0.6938776 | Fold07 |
| 0.55 | 0.0005010 | 0.9500000 | 0.8979592 | Fold07 |
| 1.00 | 0.0501036 | 0.7500000 | 0.4897959 | Fold07 |
| 1.00 | 0.0050104 | 0.8500000 | 0.6938776 | Fold07 |
| 1.00 | 0.0005010 | 0.9500000 | 0.8979592 | Fold07 |
| 0.10 | 0.0501036 | 0.8636364 | 0.7226891 | Fold08 |
| 0.10 | 0.0050104 | 0.8636364 | 0.7226891 | Fold08 |
| 0.10 | 0.0005010 | 0.8636364 | 0.7226891 | Fold08 |
| 0.55 | 0.0501036 | 0.8181818 | 0.6333333 | Fold08 |
| 0.55 | 0.0050104 | 0.8636364 | 0.7226891 | Fold08 |
| 0.55 | 0.0005010 | 0.8636364 | 0.7226891 | Fold08 |
| 1.00 | 0.0501036 | 0.8636364 | 0.7226891 | Fold08 |
| 1.00 | 0.0050104 | 0.8636364 | 0.7226891 | Fold08 |
| 1.00 | 0.0005010 | 0.8636364 | 0.7226891 | Fold08 |
| 0.10 | 0.0501036 | 0.9047619 | 0.8090909 | Fold09 |
| 0.10 | 0.0050104 | 0.9047619 | 0.8090909 | Fold09 |
| 0.10 | 0.0005010 | 0.9047619 | 0.8090909 | Fold09 |
| 0.55 | 0.0501036 | 0.8095238 | 0.6181818 | Fold09 |
| 0.55 | 0.0050104 | 0.9047619 | 0.8090909 | Fold09 |
| 0.55 | 0.0005010 | 0.9047619 | 0.8090909 | Fold09 |
| 1.00 | 0.0501036 | 0.8095238 | 0.6181818 | Fold09 |
| 1.00 | 0.0050104 | 0.9047619 | 0.8090909 | Fold09 |
| 1.00 | 0.0005010 | 0.9047619 | 0.8090909 | Fold09 |
| 0.10 | 0.0501036 | 0.8636364 | 0.7317073 | Fold10 |
| 0.10 | 0.0050104 | 0.9090909 | 0.8196721 | Fold10 |
| 0.10 | 0.0005010 | 0.9090909 | 0.8196721 | Fold10 |
| 0.55 | 0.0501036 | 0.7727273 | 0.5600000 | Fold10 |
| 0.55 | 0.0050104 | 0.9090909 | 0.8196721 | Fold10 |
| 0.55 | 0.0005010 | 0.9090909 | 0.8196721 | Fold10 |
| 1.00 | 0.0501036 | 0.7727273 | 0.5600000 | Fold10 |
| 1.00 | 0.0050104 | 0.9090909 | 0.8196721 | Fold10 |
| 1.00 | 0.0005010 | 0.9090909 | 0.8196721 | Fold10 |
We can plot the results:
performance$alpha <- as.factor(performance$alpha)
performance$lambda <- as.factor(performance$lambda)
ggplot(data = performance, aes(x = alpha, y = Accuracy,color=lambda)) +
geom_boxplot() +
geom_point(position=position_jitterdodge())+
labs(x = "") +
theme_bw()

4.2.2 Tuning caret
control_train <-
trainControl(
method = "cv",
# resampling method: "boot", "cv", "none", etc.
number = 10,
# number of folds or number of resampling iterations.
returnResamp = "all",
classProbs = TRUE,
summaryFunction = twoClassSummary,
#a function to compute performance metrics across resamples.
search = "grid",
savePredictions = TRUE
)

How can we do a custom hyperparameter search? With the help of the function expand.grid and the parameter tuneGrid of the function train:
lambda <- c(0,0.01, 0.1)
alpha <- c(0,0.1,0.3,0.5, 0.9, 1)
hyper_grid <- expand.grid(alpha = alpha, lambda = lambda)
set.seed(123)
modelo_glm_caret_grid <- train(
target ~ .,
method = "glmnet",
family = "binomial",
trControl = control_train,
data = datos_train_prep,
tuneGrid = hyper_grid,
metric = "ROC"
)
modelo_glm_caret_grid

## glmnet
##
## 211 samples
## 13 predictor
## 2 classes: 'Yes', 'No'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 190, 189, 191, 191, 189, 190, ...
## Resampling results across tuning parameters:
##
## alpha lambda ROC Sens Spec
## 0.0 0.00 0.8829461 0.7588889 0.8780303
## 0.0 0.01 0.8829461 0.7588889 0.8780303
## 0.0 0.10 0.8871886 0.7700000 0.8613636
## 0.1 0.00 0.8674411 0.7500000 0.8787879
## 0.1 0.01 0.8737542 0.7377778 0.8787879
## 0.1 0.10 0.8871801 0.7600000 0.8530303
## 0.3 0.00 0.8666077 0.7500000 0.8787879
## 0.3 0.01 0.8754966 0.7377778 0.8787879
## 0.3 0.10 0.8774579 0.7500000 0.8363636
## 0.5 0.00 0.8674411 0.7500000 0.8787879
## 0.5 0.01 0.8764057 0.7377778 0.8696970
## 0.5 0.10 0.8680640 0.7277778 0.8363636
## 0.9 0.00 0.8674411 0.7500000 0.8787879
## 0.9 0.01 0.8727189 0.7477778 0.8613636
## 0.9 0.10 0.8305724 0.7066667 0.8000000
## 1.0 0.00 0.8674411 0.7500000 0.8787879
## 1.0 0.01 0.8736448 0.7477778 0.8613636
## 1.0 0.10 0.8277189 0.6855556 0.8000000
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0 and lambda = 0.1.
## alpha lambda
## 3 0 0.1
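The best combination can also be inspected programmatically: caret stores it in the bestTune element, and the coefficients of the fitted glmnet model can be read off the final model at the selected lambda. A minimal sketch, assuming the modelo_glm_caret_grid object trained above:

```r
# best hyperparameter combination selected by the grid search
modelo_glm_caret_grid$bestTune

# coefficients of the final glmnet fit at the selected lambda
coef(
  modelo_glm_caret_grid$finalModel,
  s = modelo_glm_caret_grid$bestTune$lambda
)
```

With alpha = 0 the penalty is pure ridge, so no coefficient is shrunk exactly to zero; a lasso or elastic-net winner would instead drop some predictors entirely.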
performance <- modelo_glm_caret_grid$resample
kbl(performance) %>%
kable_paper() %>%
scroll_box(width = "100%", height = "200px")
| alpha | lambda | ROC | Sens | Spec | Resample |
|---|---|---|---|---|---|
| 0.0 | 0.10 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.0 | 0.00 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.0 | 0.01 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.1 | 0.10 | 0.9629630 | 0.7777778 | 1.0000000 | Fold01 |
| 0.1 | 0.00 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.1 | 0.01 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.3 | 0.10 | 0.9629630 | 0.8888889 | 1.0000000 | Fold01 |
| 0.3 | 0.00 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.3 | 0.01 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.5 | 0.10 | 0.9629630 | 0.7777778 | 1.0000000 | Fold01 |
| 0.5 | 0.00 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.5 | 0.01 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.9 | 0.10 | 0.9537037 | 0.6666667 | 1.0000000 | Fold01 |
| 0.9 | 0.00 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 0.9 | 0.01 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 1.0 | 0.10 | 0.9537037 | 0.5555556 | 1.0000000 | Fold01 |
| 1.0 | 0.00 | 0.9537037 | 0.7777778 | 1.0000000 | Fold01 |
| 1.0 | 0.01 | 0.9629630 | 0.7777778 | 1.0000000 | Fold01 |
| 0.0 | 0.10 | 0.7500000 | 0.7000000 | 0.7500000 | Fold02 |
| 0.0 | 0.00 | 0.7333333 | 0.7000000 | 0.8333333 | Fold02 |
| 0.0 | 0.01 | 0.7333333 | 0.7000000 | 0.8333333 | Fold02 |
| 0.1 | 0.10 | 0.7416667 | 0.7000000 | 0.7500000 | Fold02 |
| 0.1 | 0.00 | 0.7166667 | 0.6000000 | 0.7500000 | Fold02 |
| 0.1 | 0.01 | 0.7083333 | 0.6000000 | 0.7500000 | Fold02 |
| 0.3 | 0.10 | 0.7250000 | 0.7000000 | 0.7500000 | Fold02 |
| 0.3 | 0.00 | 0.7166667 | 0.6000000 | 0.7500000 | Fold02 |
| 0.3 | 0.01 | 0.7166667 | 0.6000000 | 0.7500000 | Fold02 |
| 0.5 | 0.10 | 0.7166667 | 0.7000000 | 0.7500000 | Fold02 |
| 0.5 | 0.00 | 0.7166667 | 0.6000000 | 0.7500000 | Fold02 |
| 0.5 | 0.01 | 0.7166667 | 0.6000000 | 0.7500000 | Fold02 |
| 0.9 | 0.10 | 0.6916667 | 0.7000000 | 0.7500000 | Fold02 |
| 0.9 | 0.00 | 0.7166667 | 0.6000000 | 0.7500000 | Fold02 |
| 0.9 | 0.01 | 0.7083333 | 0.7000000 | 0.7500000 | Fold02 |
| 1.0 | 0.10 | 0.7083333 | 0.7000000 | 0.7500000 | Fold02 |
| 1.0 | 0.00 | 0.7166667 | 0.6000000 | 0.7500000 | Fold02 |
| 1.0 | 0.01 | 0.7083333 | 0.7000000 | 0.7500000 | Fold02 |
| 0.0 | 0.10 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.0 | 0.00 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.0 | 0.01 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.1 | 0.10 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.1 | 0.00 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.1 | 0.01 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.3 | 0.10 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.3 | 0.00 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.3 | 0.01 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.5 | 0.10 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.5 | 0.00 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.5 | 0.01 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.9 | 0.10 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.9 | 0.00 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.9 | 0.01 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 1.0 | 0.10 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 1.0 | 0.00 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 1.0 | 0.01 | 1.0000000 | 1.0000000 | 1.0000000 | Fold03 |
| 0.0 | 0.10 | 0.8080808 | 0.3333333 | 0.8181818 | Fold04 |
| 0.0 | 0.00 | 0.7878788 | 0.3333333 | 0.8181818 | Fold04 |
| 0.0 | 0.01 | 0.7878788 | 0.3333333 | 0.8181818 | Fold04 |
| 0.1 | 0.10 | 0.8080808 | 0.3333333 | 0.9090909 | Fold04 |
| 0.1 | 0.00 | 0.6666667 | 0.3333333 | 0.8181818 | Fold04 |
| 0.1 | 0.01 | 0.7272727 | 0.2222222 | 0.8181818 | Fold04 |
| 0.3 | 0.10 | 0.7777778 | 0.3333333 | 0.9090909 | Fold04 |
| 0.3 | 0.00 | 0.6666667 | 0.3333333 | 0.8181818 | Fold04 |
| 0.3 | 0.01 | 0.7272727 | 0.2222222 | 0.8181818 | Fold04 |
| 0.5 | 0.10 | 0.7373737 | 0.3333333 | 0.9090909 | Fold04 |
| 0.5 | 0.00 | 0.6666667 | 0.3333333 | 0.8181818 | Fold04 |
| 0.5 | 0.01 | 0.7272727 | 0.2222222 | 0.8181818 | Fold04 |
| 0.9 | 0.10 | 0.7070707 | 0.3333333 | 0.7272727 | Fold04 |
| 0.9 | 0.00 | 0.6666667 | 0.3333333 | 0.8181818 | Fold04 |
| 0.9 | 0.01 | 0.7070707 | 0.2222222 | 0.8181818 | Fold04 |
| 1.0 | 0.10 | 0.7070707 | 0.3333333 | 0.7272727 | Fold04 |
| 1.0 | 0.00 | 0.6666667 | 0.3333333 | 0.8181818 | Fold04 |
| 1.0 | 0.01 | 0.7070707 | 0.2222222 | 0.8181818 | Fold04 |
| 0.0 | 0.10 | 0.8166667 | 0.7000000 | 0.8333333 | Fold05 |
| 0.0 | 0.00 | 0.8250000 | 0.7000000 | 0.8333333 | Fold05 |
| 0.0 | 0.01 | 0.8250000 | 0.7000000 | 0.8333333 | Fold05 |
| 0.1 | 0.10 | 0.8166667 | 0.7000000 | 0.8333333 | Fold05 |
| 0.1 | 0.00 | 0.8166667 | 0.6000000 | 0.8333333 | Fold05 |
| 0.1 | 0.01 | 0.8166667 | 0.7000000 | 0.8333333 | Fold05 |
| 0.3 | 0.10 | 0.7833333 | 0.7000000 | 0.7500000 | Fold05 |
| 0.3 | 0.00 | 0.8083333 | 0.6000000 | 0.8333333 | Fold05 |
| 0.3 | 0.01 | 0.8166667 | 0.7000000 | 0.8333333 | Fold05 |
| 0.5 | 0.10 | 0.7666667 | 0.7000000 | 0.7500000 | Fold05 |
| 0.5 | 0.00 | 0.8166667 | 0.6000000 | 0.8333333 | Fold05 |
| 0.5 | 0.01 | 0.8166667 | 0.7000000 | 0.8333333 | Fold05 |
| 0.9 | 0.10 | 0.7583333 | 0.7000000 | 0.7500000 | Fold05 |
| 0.9 | 0.00 | 0.8166667 | 0.6000000 | 0.8333333 | Fold05 |
| 0.9 | 0.01 | 0.8083333 | 0.7000000 | 0.8333333 | Fold05 |
| 1.0 | 0.10 | 0.7416667 | 0.6000000 | 0.6666667 | Fold05 |
| 1.0 | 0.00 | 0.8166667 | 0.6000000 | 0.8333333 | Fold05 |
| 1.0 | 0.01 | 0.8083333 | 0.7000000 | 0.8333333 | Fold05 |
| 0.0 | 0.10 | 0.7545455 | 0.6000000 | 0.8181818 | Fold06 |
| 0.0 | 0.00 | 0.7363636 | 0.6000000 | 0.8181818 | Fold06 |
| 0.0 | 0.01 | 0.7363636 | 0.6000000 | 0.8181818 | Fold06 |
| 0.1 | 0.10 | 0.7636364 | 0.6000000 | 0.8181818 | Fold06 |
| 0.1 | 0.00 | 0.7181818 | 0.6000000 | 0.7272727 | Fold06 |
| 0.1 | 0.01 | 0.7181818 | 0.6000000 | 0.8181818 | Fold06 |
| 0.3 | 0.10 | 0.7818182 | 0.5000000 | 0.8181818 | Fold06 |
| 0.3 | 0.00 | 0.7181818 | 0.6000000 | 0.7272727 | Fold06 |
| 0.3 | 0.01 | 0.7181818 | 0.6000000 | 0.8181818 | Fold06 |
| 0.5 | 0.10 | 0.7909091 | 0.5000000 | 0.8181818 | Fold06 |
| 0.5 | 0.00 | 0.7181818 | 0.6000000 | 0.7272727 | Fold06 |
| 0.5 | 0.01 | 0.7272727 | 0.6000000 | 0.8181818 | Fold06 |
| 0.9 | 0.10 | 0.7090909 | 0.4000000 | 0.7272727 | Fold06 |
| 0.9 | 0.00 | 0.7181818 | 0.6000000 | 0.7272727 | Fold06 |
| 0.9 | 0.01 | 0.7272727 | 0.6000000 | 0.8181818 | Fold06 |
| 1.0 | 0.10 | 0.7090909 | 0.4000000 | 0.7272727 | Fold06 |
| 1.0 | 0.00 | 0.7181818 | 0.6000000 | 0.7272727 | Fold06 |
| 1.0 | 0.01 | 0.7272727 | 0.6000000 | 0.8181818 | Fold06 |
| 0.0 | 0.10 | 0.9191919 | 0.8888889 | 0.8181818 | Fold07 |
| 0.0 | 0.00 | 0.9393939 | 0.7777778 | 0.8181818 | Fold07 |
| 0.0 | 0.01 | 0.9393939 | 0.7777778 | 0.8181818 | Fold07 |
| 0.1 | 0.10 | 0.9090909 | 0.8888889 | 0.8181818 | Fold07 |
| 0.1 | 0.00 | 0.9494949 | 0.8888889 | 1.0000000 | Fold07 |
| 0.1 | 0.01 | 0.9595960 | 0.7777778 | 0.9090909 | Fold07 |
| 0.3 | 0.10 | 0.8989899 | 0.7777778 | 0.8181818 | Fold07 |
| 0.3 | 0.00 | 0.9494949 | 0.8888889 | 1.0000000 | Fold07 |
| 0.3 | 0.01 | 0.9595960 | 0.7777778 | 0.9090909 | Fold07 |
| 0.5 | 0.10 | 0.8787879 | 0.6666667 | 0.8181818 | Fold07 |
| 0.5 | 0.00 | 0.9494949 | 0.8888889 | 1.0000000 | Fold07 |
| 0.5 | 0.01 | 0.9595960 | 0.7777778 | 0.8181818 | Fold07 |
| 0.9 | 0.10 | 0.7676768 | 0.6666667 | 0.7272727 | Fold07 |
| 0.9 | 0.00 | 0.9494949 | 0.8888889 | 1.0000000 | Fold07 |
| 0.9 | 0.01 | 0.9595960 | 0.7777778 | 0.8181818 | Fold07 |
| 1.0 | 0.10 | 0.7474747 | 0.6666667 | 0.6363636 | Fold07 |
| 1.0 | 0.00 | 0.9494949 | 0.8888889 | 1.0000000 | Fold07 |
| 1.0 | 0.01 | 0.9595960 | 0.7777778 | 0.8181818 | Fold07 |
| 0.0 | 0.10 | 0.9666667 | 0.8000000 | 0.8333333 | Fold08 |
| 0.0 | 0.00 | 0.9583333 | 0.8000000 | 0.9166667 | Fold08 |
| 0.0 | 0.01 | 0.9583333 | 0.8000000 | 0.9166667 | Fold08 |
| 0.1 | 0.10 | 0.9666667 | 0.8000000 | 0.8333333 | Fold08 |
| 0.1 | 0.00 | 0.9666667 | 0.8000000 | 0.9166667 | Fold08 |
| 0.1 | 0.01 | 0.9583333 | 0.8000000 | 0.9166667 | Fold08 |
| 0.3 | 0.10 | 0.9666667 | 0.8000000 | 0.9166667 | Fold08 |
| 0.3 | 0.00 | 0.9666667 | 0.8000000 | 0.9166667 | Fold08 |
| 0.3 | 0.01 | 0.9583333 | 0.8000000 | 0.9166667 | Fold08 |
| 0.5 | 0.10 | 0.9666667 | 0.8000000 | 0.9166667 | Fold08 |
| 0.5 | 0.00 | 0.9666667 | 0.8000000 | 0.9166667 | Fold08 |
| 0.5 | 0.01 | 0.9583333 | 0.8000000 | 0.9166667 | Fold08 |
| 0.9 | 0.10 | 0.8666667 | 0.8000000 | 0.9166667 | Fold08 |
| 0.9 | 0.00 | 0.9666667 | 0.8000000 | 0.9166667 | Fold08 |
| 0.9 | 0.01 | 0.9583333 | 0.8000000 | 0.9166667 | Fold08 |
| 1.0 | 0.10 | 0.8583333 | 0.8000000 | 0.9166667 | Fold08 |
| 1.0 | 0.00 | 0.9666667 | 0.8000000 | 0.9166667 | Fold08 |
| 1.0 | 0.01 | 0.9583333 | 0.8000000 | 0.9166667 | Fold08 |
| 0.0 | 0.10 | 0.9363636 | 0.9000000 | 0.9090909 | Fold09 |
| 0.0 | 0.00 | 0.9454545 | 0.9000000 | 0.9090909 | Fold09 |
| 0.0 | 0.01 | 0.9454545 | 0.9000000 | 0.9090909 | Fold09 |
| 0.1 | 0.10 | 0.9363636 | 0.8000000 | 0.8181818 | Fold09 |
| 0.1 | 0.00 | 0.9363636 | 0.9000000 | 0.9090909 | Fold09 |
| 0.1 | 0.01 | 0.9454545 | 0.9000000 | 0.9090909 | Fold09 |
| 0.3 | 0.10 | 0.9363636 | 0.8000000 | 0.8181818 | Fold09 |
| 0.3 | 0.00 | 0.9363636 | 0.9000000 | 0.9090909 | Fold09 |
| 0.3 | 0.01 | 0.9545455 | 0.9000000 | 0.9090909 | Fold09 |
| 0.5 | 0.10 | 0.9272727 | 0.8000000 | 0.8181818 | Fold09 |
| 0.5 | 0.00 | 0.9363636 | 0.9000000 | 0.9090909 | Fold09 |
| 0.5 | 0.01 | 0.9545455 | 0.9000000 | 0.9090909 | Fold09 |
| 0.9 | 0.10 | 0.9181818 | 0.8000000 | 0.8181818 | Fold09 |
| 0.9 | 0.00 | 0.9363636 | 0.9000000 | 0.9090909 | Fold09 |
| 0.9 | 0.01 | 0.9545455 | 0.9000000 | 0.9090909 | Fold09 |
| 1.0 | 0.10 | 0.9181818 | 0.8000000 | 0.9090909 | Fold09 |
| 1.0 | 0.00 | 0.9363636 | 0.9000000 | 0.9090909 | Fold09 |
| 1.0 | 0.01 | 0.9545455 | 0.9000000 | 0.9090909 | Fold09 |
| 0.0 | 0.10 | 0.9666667 | 1.0000000 | 0.8333333 | Fold10 |
| 0.0 | 0.00 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.0 | 0.01 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.1 | 0.10 | 0.9666667 | 1.0000000 | 0.7500000 | Fold10 |
| 0.1 | 0.00 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.1 | 0.01 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.3 | 0.10 | 0.9416667 | 1.0000000 | 0.5833333 | Fold10 |
| 0.3 | 0.00 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.3 | 0.01 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.5 | 0.10 | 0.9333333 | 1.0000000 | 0.5833333 | Fold10 |
| 0.5 | 0.00 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.5 | 0.01 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.9 | 0.10 | 0.9333333 | 1.0000000 | 0.5833333 | Fold10 |
| 0.9 | 0.00 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 0.9 | 0.01 | 0.9500000 | 1.0000000 | 0.7500000 | Fold10 |
| 1.0 | 0.10 | 0.9333333 | 1.0000000 | 0.6666667 | Fold10 |
| 1.0 | 0.00 | 0.9500000 | 1.0000000 | 0.8333333 | Fold10 |
| 1.0 | 0.01 | 0.9500000 | 1.0000000 | 0.7500000 | Fold10 |
performance$alpha <- as.factor(performance$alpha)
performance$lambda <- as.factor(performance$lambda)
ggplot(data = performance, aes(x = alpha, y = ROC,color=lambda)) +
geom_boxplot() +
geom_point(position=position_jitterdodge())+
labs(x = "") +
theme_bw()
5. Test and compare models
We use the library precrec to compare the ROC curves of all the models.
5.1 train
#for models created with glmnet
pred_fit2 <- predict(
object = fit_logistic_regression,
newx = data.matrix(datos_train_prep[, -14])
)
pred_fit3 <- predict(
object = fit_LogReg_cv_lasso,
newx = data.matrix(datos_train_prep[, -14]),
s = "lambda.min"
)
pred_fit4 <- predict(
object = fit_LogReg_cv_ridge,
newx = data.matrix(datos_train_prep[, -14]),
s = "lambda.min"
)
pred_fit5 <- predict(
object = fit_LogReg_cv_en,
newx = data.matrix(datos_train_prep[, -14]),
s = "lambda.min"
)
#for models created with caret:
pred_caret <- predict(
object = modelo_glm_caret,
newdata = datos_train_prep[-14],
type = "prob"
)
pred_fit6 <- pred_caret$No # class 0 ("No") in this case means the patient does have a heart attack
pred_caret <- predict(
object = modelo_glm_caret_grid,
newdata = datos_train_prep[-14],
type = "prob"
)
pred_fit7 <- pred_caret$No
labels <- (as.vector(datos_train_prep$target) == "No") + 0 # we are predicting not having a heart attack
mis_models <- mmdata(
list(
as.vector(pred_fit2),
as.vector(pred_fit3),
as.vector(pred_fit4),
as.vector(pred_fit5),
as.vector(pred_fit6),
as.vector(pred_fit7)
),
labels,
modnames = c(
"logistic_regression",
"lasso",
"ridge",
"elastic-net",
"caret",
"caret_grid"
)
)
auroc <- evalmod(mis_models)
autoplot(auroc)
5.2 test
Predictions:
#for models created with glmnet
pred_fit2 <- predict(
object = fit_logistic_regression,
newx = data.matrix(datos_test_prep[, -14])
)
pred_fit3 <- predict(
object = fit_LogReg_cv_lasso,
newx = data.matrix(datos_test_prep[, -14]),
s = "lambda.min"
)
pred_fit4 <- predict(
object = fit_LogReg_cv_ridge,
newx = data.matrix(datos_test_prep[, -14]),
s = "lambda.min"
)
pred_fit5 <- predict(
object = fit_LogReg_cv_en,
newx = data.matrix(datos_test_prep[, -14]),
s = "lambda.min"
)
#for models created with caret:
pred_caret <- predict(
object = modelo_glm_caret,
newdata = datos_test_prep[-14],
type = "prob"
)
pred_fit6 <- pred_caret$No # class 0 ("No") in this case means the patient does have a heart attack
pred_caret <- predict(
object = modelo_glm_caret_grid,
newdata = datos_test_prep[-14],
type = "prob"
)
pred_fit7 <- pred_caret$No
Final plots:
labels <- (as.vector(datos_test_prep$target) == "No") + 0 #we are predicting not to have a heart attack
mis_models <- mmdata(
list(
as.vector(pred_fit2),
as.vector(pred_fit3),
as.vector(pred_fit4),
as.vector(pred_fit5),
as.vector(pred_fit6),
as.vector(pred_fit7)
),
labels,
modnames = c(
"logistic_regression",
"lasso",
"ridge",
"elastic-net",
"caret",
"caret_grid"
)
)
auroc <- evalmod(mis_models)
autoplot(auroc)
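Besides the visual comparison, precrec can also report the areas under the curves numerically, which makes it easier to rank the six models on the test set. A minimal sketch, assuming the mis_models object built above:

```r
# evaluate ROC and precision-recall curves for all models
curves <- evalmod(mis_models)

# data frame with one AUC per model and curve type (ROC and PRC)
auc(curves)
```

Sorting this data frame by the aucs column gives a direct ranking of logistic_regression, lasso, ridge, elastic-net, caret and caret_grid without reading values off the plot.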